# -*- coding: UTF-8 -*-

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear3(nn.Module):
    """Three stacked bias-free square linear maps.

    ``forward(x)`` computes ``x @ W0.T @ W1.T @ W2.T`` (``F.linear`` without
    a bias). There is no non-linearity, so the whole model is a single
    linear map expressed as a product of three matrices.

    Args:
        is_identity: if True, every weight starts as the identity matrix, so
            the untrained model is the identity map. Otherwise the weights use
            Xavier-normal initialisation.
        dim: input/output feature size of every layer (default 10).
    """

    def __init__(self, is_identity=False, dim=10):
        super().__init__()
        # Separate attributes (not a ParameterList) keep the state_dict keys
        # ``fc0_weight``, ``fc1_weight`` and ``fc2_weight`` stable.
        self.fc0_weight = nn.Parameter(torch.zeros((dim, dim), dtype=torch.float))
        self.fc1_weight = nn.Parameter(torch.zeros((dim, dim), dtype=torch.float))
        self.fc2_weight = nn.Parameter(torch.zeros((dim, dim), dtype=torch.float))

        init = nn.init.eye_ if is_identity else nn.init.xavier_normal_
        for weight in (self.fc0_weight, self.fc1_weight, self.fc2_weight):
            init(weight)

    def forward(self, x):
        """Apply the three weight matrices in order; last dim of ``x`` must equal ``dim``."""
        for weight in (self.fc0_weight, self.fc1_weight, self.fc2_weight):
            x = F.linear(x, weight)
        return x


if __name__ == '__main__':
    data_path = "./data/data.pkl"
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    x, y = data["x"], data["y"]
    data_len = len(x)
    train_data_len = int(data_len / 2.)
    x_train, x_test = x[:train_data_len], x[train_data_len:]
    y_train, y_test = y[:train_data_len], y[train_data_len:]

    dataset = torch.utils.data.TensorDataset(x_train, y_train)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True, num_workers=8, pin_memory=True)

    epochs = 40

    # choose initialization with/without identity matrix
    model = Linear3(is_identity=True)
    # model = Linear3(is_identity=False)

    # choose SGD with/without the momentum
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-1, momentum=0.9)
    # optimizer = torch.optim.SGD(model.parameters(), lr=5e-1)

    model.cuda()
    model.train()
    for epoch in range(epochs):
        for input, target in dataset_loader:
            input, target = input.cuda(), target.cuda()
            pred = model(input)
            loss = F.mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model.eval()
    x_test = x_test.cuda()
    y_test = y_test.cuda()
    with torch.no_grad():
        pred = model(x_test)
        loss = F.mse_loss(pred, y_test)
    print("Test Loss:", loss.cpu().numpy())

